import gym
import highway_env

from stable_baselines3 import A2C
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils.tensorboard import SummaryWriter
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# Create the environment.
# Monitor only places its CSV inside ./logs when that directory already
# exists; otherwise './logs' is treated as a file-name prefix. Create the
# directory up front so the log really ends up under ./logs.
import os

os.makedirs('./logs', exist_ok=True)
env = gym.make('roundabout-v0')
env = Monitor(env, './logs')  # episode statistics are logged under ./logs

# Build the A2C agent.
model = A2C(policy="MlpPolicy",
            env=env, verbose=1,
            learning_rate=0.01,
            n_steps=10,
            gamma=0.95)

# Train the agent.
model.learn(total_timesteps=int(3e4))

# Evaluate the trained agent and report the outcome instead of discarding it.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print('mean_reward is:', mean_reward, '+/-', std_reward)

# Save the agent to disk.
model.save("roundabout_A2C")
# Reload the agent from the file that was just written.
model = A2C.load("roundabout_A2C")
# Roll out the loaded model and count how many episodes end in a collision.

NUM_EPISODES = 50
number_of_collisions = 0
# T is a 1-based count of environment steps across all episodes, so the
# "crashrate" printed inside the loop is collisions per step, not per episode.
T = 1
try:
    for episode in range(NUM_EPISODES):
        done = truncated = False
        obs, info = env.reset()
        while not (done or truncated):
            # NOTE(review): predict() is called with its default arguments,
            # which presumably samples actions stochastically - confirm that
            # this is intended for the collision statistics.
            action, _states = model.predict(obs)
            # env.step is given a plain Python scalar rather than an array.
            obs, reward, done, truncated, info = env.step(action.item(0))

            # Read the flag once, tolerating a missing key.
            crashed = info.get('crashed')
            print(crashed)
            if crashed:
                number_of_collisions += 1

            # Render once per step; the frames captured previously were unused.
            env.render()
            print('crashrate is '+str(float(number_of_collisions)/T)+' and T is'+str(T))
            T += 1
finally:
    # Release the viewer/window even if a rollout is interrupted.
    env.close()

print('number_of_collisions is:', number_of_collisions)
print('crash rate per episode is:', number_of_collisions / NUM_EPISODES)
print('DONE')
